{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "JAX Colab CPU Test",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/google/jax/blob/main/tests/notebooks/colab_cpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WkadOyTDCAWD",
        "colab_type": "text"
      },
      "source": [
        "# JAX Colab CPU Test\n",
        "\n",
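        "This notebook is meant to be run in a [Colab](https://colab.research.google.com) CPU runtime as a basic smoke test for JAX updates. It confirms that JAX runs on the CPU backend, then exercises matrix multiplication, linear algebra, and JIT compilation."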
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_tKNrbqqBHwu",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        },
        "outputId": "071fb360-ddf5-41ae-d772-acc08ec71d9b"
      },
      "source": [
        "import jax\n",
        "import jaxlib\n",
        "\n",
        "!cat /var/colab/hostname\n",
        "print(jax.__version__)\n",
        "print(jaxlib.__version__)"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "m-s-1p12yf76kgzz\n",
            "0.1.64\n",
            "0.1.45\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oqEG21rADO1F",
        "colab_type": "text"
      },
      "source": [
        "## Confirm Device"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "8BwzMYhKGQj6",
        "outputId": "f79a44e3-4303-494c-9288-a4e582bb34cb",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        }
      },
      "source": [
        "import jax\n",
        "key = jax.random.PRNGKey(1701)\n",
        "arr = jax.random.normal(key, (1000,))\n",
        "device = list(arr.devices())[0]\n",
        "print(f\"JAX device type: {device}\")\n",
        "assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.\n",
            "  warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "JAX device type: cpu:0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z0FUY9yUC4k1",
        "colab_type": "text"
      },
      "source": [
        "## Matrix Multiplication"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "eXn8GUl6CG5N",
        "outputId": "307aa669-76f1-4117-b62a-7acb2aee2c16",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "import jax\n",
        "\n",
        "# matrix multiplication on CPU\n",
        "key = jax.random.PRNGKey(0)\n",
        "x = jax.random.normal(key, (3000, 3000))\n",
        "result = jax.numpy.dot(x, x.T).mean()\n",
        "print(result)"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "1.0216691\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0zTA2Q19DW4G",
        "colab_type": "text"
      },
      "source": [
        "## Linear Algebra"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uW9j84_UDYof",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        },
        "outputId": "3dd5d7c0-9d47-4be1-c6f7-068b432b69f7"
      },
      "source": [
        "import jax.numpy as jnp\n",
        "import jax.random as rand\n",
        "\n",
        "N = 10\n",
        "M = 20\n",
        "K = min(N, M)\n",
        "key = rand.PRNGKey(1701)\n",
        "\n",
        "X = rand.normal(key, (N, M))\n",
        "u, s, vt = jnp.linalg.svd(X)\n",
        "assert u.shape == (N, N)\n",
        "assert s.shape == (K,)\n",
        "assert vt.shape == (M, M)\n",
        "# check the decomposition itself, not just the shapes: U_K diag(s) Vt_K == X\n",
        "assert jnp.allclose(u[:, :K] @ jnp.diag(s) @ vt[:K], X, atol=1e-3)\n",
        "print(s)"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[6.9178133 5.9580317 5.581113  4.506963  4.111582  3.973543  3.3307292\n",
            " 2.8664916 1.8229378 1.5478933]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jCyKUn4-DCXn",
        "colab_type": "text"
      },
      "source": [
        "## XLA Compilation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2GOn_HhDPuEn",
        "outputId": "41a40dd9-3680-458d-cedd-81ebcc2ab06f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "import jax\n",
        "\n",
        "@jax.jit\n",
        "def selu(x, alpha=1.67, lmbda=1.05):\n",
        "  return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n",
        "\n",
        "# define the key locally so this cell runs independently of earlier cells\n",
        "key = jax.random.PRNGKey(1701)\n",
        "x = jax.random.normal(key, (5000,))\n",
        "result = selu(x).block_until_ready()\n",
        "print(result)"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[ 0.34676832 -0.7532232   1.7060695  ...  2.1208048  -0.42621925\n",
            "  0.13093236]\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}
